#!/usr/bin/env python3

"""Diff 2 JSON objects."""

from __future__ import annotations

import argparse
import json
import sys
from collections.abc import Iterator
from difflib import unified_diff
from fnmatch import fnmatchcase
from pathlib import Path

from colorama import Fore, Style

# Name shown in usage and error messages: this script's file name without
# its extension.
PROG: str = Path(__file__).stem

# Default number of unchanged lines shown around each change (the -n option).
CONTEXT_LINES: int = 3


# ------------------------------------------------------------------------------
def process_cli_args(argv: list[str] | None = None) -> argparse.Namespace:
    """
    Process the command line arguments.

    :param argv:    The arguments to parse. If None, sys.argv[1:] is used.
    :return:        The args namespace. ``args.files`` holds the two opened
                    input files, ``args.ignore`` the list of glob patterns and
                    ``args.n`` the number of context lines.
    """

    argp = argparse.ArgumentParser(prog=PROG, description='Diff 2 JSON objects.')

    argp.add_argument(
        '-i',
        '--ignore',
        action='append',
        metavar='GLOB',
        default=[],
        help=(
            'Ignore top level fields that match the specified glob pattern.'
            ' Can be specified multiple times.'
        ),
    )

    argp.add_argument(
        '-n',
        type=int,
        metavar='LINES',
        default=CONTEXT_LINES,
        help=f'Number of context lines. Default is {CONTEXT_LINES}.',
    )

    argp.add_argument(
        'files',
        metavar='FILE',
        nargs=2,
        # JSON exchanged between systems is UTF-8 (RFC 8259), so don't depend
        # on the locale's default encoding. ('-' still maps to sys.stdin.)
        type=argparse.FileType('r', encoding='utf-8'),
        help=(
            'Names of two files, each containing a single JSON object.'
            ' One of the files can be - for stdin.'
        ),
    )

    args = argp.parse_args(argv)

    # difflib produces nonsensical hunk ranges for a negative context size.
    if args.n < 0:
        argp.error('argument -n: LINES must not be negative')
    if args.files[0].name == args.files[1].name == '<stdin>':
        argp.error('Only one of the input files can be stdin')
    return args


# ------------------------------------------------------------------------------
def match_any(s: str, globs: list[str], ignore_case: bool = False) -> bool:
    """
    Report whether a string matches at least one of several glob patterns.

    Matching uses fnmatch.fnmatchcase semantics (``*``, ``?``, ``[seq]``), so
    it is case sensitive on every platform unless ignore_case is set.

    :param s:           The candidate string.
    :param globs:       The glob style patterns to try.
    :param ignore_case: If True, lower-case both the string and every pattern
                        before matching.
    :return:            True as soon as a pattern matches, False if none do.
    """

    def normalise(text: str) -> str:
        return text.lower() if ignore_case else text

    candidate = normalise(s)
    for pattern in [normalise(g) for g in globs]:
        if fnmatchcase(candidate, pattern):
            return True
    return False


# ------------------------------------------------------------------------------
def json_diff(
    d1: dict, d2: dict, ignore: list[str] | None = None, context_lines: int = CONTEXT_LINES
) -> Iterator[str]:
    """
    Calculate a unified diff of two JSON objects (dicts).

    Each object is serialised with sorted keys and 4-space indentation before
    diffing, so differences in key order are not reported.

    :param d1:      The first JSON object.
    :param d2:      The second JSON object.
    :param ignore:  A list of glob patterns. Top level keys matching any of the
                    glob patterns will be ignored. None ignores nothing.
    :param context_lines: Number of context lines. Must not be negative.
    :return:        An iterator over the lines of a unified diff of the two
                    objects; empty if they are equal (after ignoring keys).
                    The '---', '+++' and '@@' control lines end with a
                    newline, the content lines do not.
    :raise TypeError:   If d1 or d2 is not a dict.
    :raise ValueError:  If context_lines is negative.
    """

    if not isinstance(d1, dict) or not isinstance(d2, dict):
        raise TypeError('d1 and d2 must be dicts.')
    if context_lines < 0:
        raise ValueError('context_lines must not be negative.')

    # The default of None previously crashed inside match_any.
    patterns = ignore or []

    def render(d: dict) -> list[str]:
        """Serialise d, minus ignored top level keys, into a list of lines."""
        kept = {k: v for k, v in d.items() if not match_any(k, patterns)}
        return json.dumps(kept, sort_keys=True, indent=4).splitlines()

    return unified_diff(render(d1), render(d2), n=context_lines)


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Load the two JSON files named on the command line and print a colourised
    unified diff of them.

    :return:    Exit status: 0 if the objects are identical (after ignoring
                any --ignore fields), 1 if they differ.
    :raise Exception: If an input file cannot be read or decoded, is not valid
                JSON, or does not contain a JSON object. The message is
                prefixed with the file name.
    """

    args = process_cli_args()
    data = []
    for fp in args.files:
        try:
            obj = json.load(fp)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            raise Exception(f'{fp.name}: {e}') from e
        finally:
            # argparse.FileType opens the files but never closes them.
            if fp is not sys.stdin:
                fp.close()
        if not isinstance(obj, dict):
            # Report this here, where the offending file name is known.
            raise Exception(f'{fp.name}: Does not contain a JSON object')
        data.append(obj)

    status = 0
    for line in json_diff(*data, ignore=args.ignore, context_lines=args.n):
        status = 1
        if line.startswith(('---', '+++')):
            # Skip the unified diff file header lines.
            continue
        if line.startswith('@@'):
            # The control lines have a trailing linefeed
            print(f'{Fore.CYAN}{line}{Style.RESET_ALL}', end='')
        elif line[0] == '-':
            print(f'{Fore.RED}{line}{Style.RESET_ALL}')
        elif line[0] == '+':
            print(f'{Fore.GREEN}{line}{Style.RESET_ALL}')
        else:
            print(line)
    return status


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Use sys.exit: the bare exit() builtin is added by the site module for
    # interactive use and is not guaranteed to exist (e.g. python -S).
    # Uncomment for debugging
    # sys.exit(main())  # noqa: ERA001
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        # Ctrl-C (e.g. while reading stdin): exit quietly with the
        # conventional 128 + SIGINT status instead of a traceback.
        sys.exit(130)
    except Exception as ex:
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
